import torch
import copy
import re
import numpy as np
import math
def extract_number(filename, order=0):
    """Return the ``order``-th run of decimal digits in ``filename`` as an int.

    Args:
        filename: String to scan for runs of digits.
        order: Index of the digit run to return (0 is the first run;
            negative values index from the end, as with list indexing).

    Returns:
        The selected digit run converted to ``int``, or 0 when ``filename``
        has no digit run at that index.
    """
    numbers = re.findall(r'\d+', filename)
    try:
        return int(numbers[order])
    except IndexError:
        # Fewer digit runs than requested: use the same 0 fallback that a
        # string without any digits gets, instead of raising.
        return 0
def tensorlist_to_tensor(weights):
    """Concatenate a sequence of tensors into a single 1-dimensional tensor.

    Each tensor is flattened with ``reshape(-1)`` before concatenation, so
    0-dimensional tensors and non-contiguous tensors (e.g. transposed views)
    are accepted as well.

    Args:
        weights: Iterable of tensors.

    Returns:
        A 1-dimensional tensor holding all elements in iteration order.
    """
    return torch.cat([w.reshape(-1) for w in weights])

def get_diff_weights(weights1, weights2):
    """Return ``weights2 - weights1`` as a single 1-dimensional tensor.

    The two sequences are paired element by element; each difference is
    flattened with ``reshape(-1)`` so 0-dimensional and non-contiguous
    results can be concatenated too.

    Args:
        weights1: Iterable of tensors (the subtrahend).
        weights2: Iterable of tensors (the minuend), paired with ``weights1``.

    Returns:
        A 1-dimensional tensor with all element-wise differences.
    """
    return torch.cat([(w2 - w1).reshape(-1) for w1, w2 in zip(weights1, weights2)])

def set_weights(model, weights, directions=None, step=None):
    """Overwrite the parameters of ``model`` and return the model.

    Without ``directions`` every parameter receives a copy of the matching
    entry of ``weights``. With ``directions`` every parameter is set to
    ``w + dx * step[0] + dy * step[1]``, where ``dx`` and ``dy`` are
    ``directions[0]`` and ``directions[1]`` reshaped to the parameter shapes.

    Args:
        model: Object exposing ``parameters()``.
        weights: Sequence of tensors paired with ``model.parameters()``.
        directions: Optional pair of flat vectors, each with one entry per
            model parameter element.
        step: Pair of coefficients for the two directions; required when
            ``directions`` is given.

    Returns:
        The same ``model`` object, modified in place.
    """
    if directions is None:
        for p, w in zip(model.parameters(), weights):
            # copy_ accepts a source of another dtype or device, so no
            # explicit cast is needed; an intermediate cast could lose
            # precision for float64 parameters.
            p.data.copy_(w)
    else:
        assert step is not None, 'If a direction is specified then step must be specified as well'
        dx = vec_to_tensorlist(directions[0], model.parameters())
        dy = vec_to_tensorlist(directions[1], model.parameters())

        for p, w, d0, d1 in zip(model.parameters(), weights, dx, dy):
            change = d0 * step[0] + d1 * step[1]
            # `change` is already a tensor; detach + type_as matches it to
            # `w` without re-wrapping it in torch.tensor().
            p.data = w + change.detach().type_as(w)
    return model

def vec_to_tensorlist(direction, params):
    """
    Split a flat vector into a list of tensors shaped like ``params``.

    Consecutive slices of ``direction`` (a NumPy vector or a 1-dimensional
    tensor) are reshaped to the size of each entry of ``params``.

    * If ``params`` is a ``list``, it is deep-copied and the slices are copied
      into the copies, so the results keep the dtype/device of ``params``.
    * Otherwise (e.g. the generator from ``model.parameters()``), new tensors
      are built directly from the slices.

    The returned tensors never share memory with ``direction``.

    Raises:
        AssertionError: if ``direction`` does not have exactly as many
            entries as ``params`` has elements in total.
    """
    def fresh_tensor(chunk):
        # torch.tensor() on an existing tensor emits a copy-construct
        # warning; detach().clone() is the documented equivalent.
        if isinstance(chunk, torch.Tensor):
            return chunk.detach().clone()
        return torch.tensor(chunk)

    idx = 0
    if isinstance(params, list):
        w2 = copy.deepcopy(params)
        # no_grad: an in-place copy_ into a leaf tensor that requires grad
        # (e.g. list(model.parameters())) is rejected by autograd otherwise.
        with torch.no_grad():
            for w in w2:
                w.copy_(fresh_tensor(direction[idx:idx + w.numel()]).view(w.size()))
                idx += w.numel()
        assert idx == len(direction)
        return w2

    s2 = []
    for w in params:
        s2.append(fresh_tensor(direction[idx:idx + w.numel()]).view(w.size()))
        idx += w.numel()
    assert idx == len(direction)
    return s2
    
def write_tensorlist(f, name, origin):
    """Write every item of ``origin`` as a dataset inside a new group ``name``.

    Datasets are named ``"0"``, ``"1"``, ... following the order of
    ``origin``. Torch tensors are detached, moved to the CPU and converted to
    NumPy arrays first; any other item is passed through unchanged.

    Args:
        f: Object providing ``create_group`` (presumably an open h5py file
            or group -- confirm against the caller).
        name: Name of the group to create.
        origin: Iterable of tensors or array-like items.
    """
    group = f.create_group(name)
    for index, item in enumerate(origin):
        data = item.detach().cpu().numpy() if isinstance(item, torch.Tensor) else item
        group.create_dataset(str(index), data=data)


def get_weights(net):
    """Return the ``.data`` tensors of all parameters of ``net`` as a list.

    A list is returned because the generator from ``net.parameters()`` can
    only be consumed once. The tensors are the parameters' own data, not
    copies.
    """
    weights = []
    for param in net.parameters():
        weights.append(param.data)
    return weights

def read_nplist(f, name):
    """Read the group ``name`` from ``f`` and return its entries as a list.

    The entries are looked up under the keys ``"0"``, ``"1"``, ...,
    ``str(len(group) - 1)`` and returned exactly as indexing yields them.
    NOTE(review): for an h5py file these are presumably dataset handles
    rather than in-memory NumPy arrays -- confirm before closing the file.
    """
    group = f[name]
    entries = []
    for position in range(len(group)):
        entries.append(group[str(position)])
    return entries
def read_tensorlist(f, name, local_rank):
    """Read the group ``name`` from ``f`` and return a list of CUDA tensors.

    The entries stored under the keys ``"0"``, ``"1"``, ... are each read
    into a NumPy array, converted to a tensor and moved to the device
    ``cuda:{local_rank}``.

    Args:
        f: Mapping-like file handle (presumably an open h5py file -- confirm).
        name: Key of the group to read.
        local_rank: Index of the CUDA device that receives the tensors.

    Returns:
        List of tensors on ``cuda:{local_rank}``, in key order.
    """
    grp = f[name]
    device = f"cuda:{local_rank}"
    # Materialize each stored entry as a NumPy array first, so torch.tensor
    # gets an ndarray rather than the stored dataset object itself.
    return [torch.tensor(np.asarray(grp[str(i)])).to(device) for i in range(len(grp))]


# def read_tensorlist(f, name, local_rank):
#     """ Read group with name as the key from the hdf5 file and return a list of tensors. """
#     grp = f[name]
    
#     # Step 1: first read all the data from HDF5 into NumPy arrays
#     tensors = [grp[str(i)] for i in range(len(grp))]
    
#     # Step 2: convert the NumPy arrays to PyTorch tensors and move them to the target GPU
#     tensors = [torch.from_numpy(tensor).to(f"cuda:{local_rank}") for tensor in tensors]

#     return tensors

def seconds2days_hours_minutes_seconds(elapsed_time_in_seconds):
    """Format a duration given in seconds as human-readable text.

    The text lists days, hours, minutes and seconds, starting at the largest
    unit whose value is non-zero (seconds are always shown), e.g.
    ``"1 hours, 0 minutes, 5 seconds"``. Every value is truncated with
    ``int`` before formatting.
    """
    total = elapsed_time_in_seconds
    days = total // 86400
    hours = (total % 86400) // 3600
    minutes = (total % 3600) // 60
    seconds = total % 60

    pieces = []
    for value, label in ((days, "days"), (hours, "hours"), (minutes, "minutes")):
        # Once a larger unit has been emitted, all smaller units follow even
        # when they are zero.
        if pieces or value != 0:
            pieces.append(f"{int(value)} {label}")
    pieces.append(f"{int(seconds)} seconds")
    return ", ".join(pieces)
# def batch_enumerate(elements, n, fill_value=float('nan')):
#     batch_num = 0  # running batch number
#     for i in range(0, len(elements), n):
#         batch_elements = elements[i:i + n]  # elements of the current batch (unpadded)
#         batch_indices = list(range(i, i + len(batch_elements)))  # indices of the current batch (unpadded)
        
#         # build the padded versions
#         padded_elements = list(batch_elements)
#         padded_indices = list(batch_indices)
#         if len(batch_elements) < n:
#             pad_size = n - len(batch_elements)
#             padded_elements += [fill_value] * pad_size
#             padded_indices += [fill_value] * pad_size
        
#         # yield the batch number plus the unpadded and padded indices and elements
#         yield batch_num, batch_indices, batch_elements, padded_indices, padded_elements
        
#         batch_num += 1  # advance the batch number

def batch_enumerate(elements, n):
    """Yield ``elements`` in consecutive batches of at most ``n`` items.

    Args:
        elements: Sliceable sequence with a length.
        n: Maximum batch size (used as the step of ``range``).

    Yields:
        Tuples ``(batch_num, batch_indices, batch_elements)``: the 0-based
        batch number, the list of positions in ``elements`` covered by the
        batch, and the corresponding slice. The last batch may be shorter
        than ``n``; it is not padded.
    """
    batch_starts = range(0, len(elements), n)
    for batch_num, start in enumerate(batch_starts):
        chunk = elements[start:start + n]
        positions = list(range(start, start + len(chunk)))
        yield batch_num, positions, chunk


# set number of streams for each bundle and rank
def get_stream_length_func(bundle_num, max_task_num, inds_nums, bundle_size):
    """Build a function mapping a rank to its stream count for one bundle.

    The returned function yields ``bundle_size`` while
    ``bundle_num <= ceil(max_task_num / bundle_size) - 2``; otherwise it
    yields ``(inds_nums[rank] - 1) % bundle_size + 1``.
    NOTE(review): this presumably means "every bundle except the last is
    full, the last holds the remainder of the rank's tasks" -- confirm with
    the caller.

    Args:
        bundle_num: Index of the bundle being processed.
        max_task_num: Task count used to derive the number of bundles.
        inds_nums: Per-rank counts, indexed by rank.
        bundle_size: Number of streams in a full bundle.

    Returns:
        A function ``streams_by_rank(rank)`` returning an int.
    """
    def streams_by_rank(rank):
        last_full_bundle = math.ceil(max_task_num / bundle_size) - 2
        remainder_streams = (inds_nums[rank] - 1) % bundle_size + 1
        return bundle_size if bundle_num <= last_full_bundle else remainder_streams
    return streams_by_rank


def auto_scale(trajectories, margin_ratio):
    """Compute axis limits that enclose all trajectories plus a margin.

    Args:
        trajectories: Mapping from a key to an array whose first dimension
            has size 2 (row 0 holds x values, row 1 holds y values). Entries
            whose first dimension is not 2 are skipped with a printed warning.
        margin_ratio: Fraction of the x/y extent added on each side.

    Returns:
        Tuple ``(xmin, xmax, ymin, ymax)``.
    """
    # Running bounds over all trajectories, stored as [x, y].
    lows = np.array([np.inf, np.inf])
    highs = np.array([-np.inf, -np.inf])

    for key, arr in trajectories.items():
        if arr.shape[0] != 2:
            print(f"Warning: Array for key '{key}' does not have shape (2, n). Skipping.")
            continue
        # Per-row extrema of this trajectory, merged into the running bounds.
        highs = np.maximum(highs, np.max(arr, axis=1))
        lows = np.minimum(lows, np.min(arr, axis=1))

    x_lo, y_lo = lows
    x_hi, y_hi = highs
    # Widen each axis by margin_ratio times its extent on both sides.
    x_pad = margin_ratio * (x_hi - x_lo)
    y_pad = margin_ratio * (y_hi - y_lo)
    return x_lo - x_pad, x_hi + x_pad, y_lo - y_pad, y_hi + y_pad


def normalize_direction(direction, weights, norm='filter'):
    """Rescale ``direction`` in place according to the scheme ``norm``.

    Supported schemes:
        'filter':  each first-dimension slice of ``direction`` gets the norm
                   of the matching slice of ``weights``.
        'layer':   ``direction`` as a whole gets the norm of ``weights``.
        'weight':  ``direction`` is multiplied element-wise by ``weights``.
        'dfilter': each first-dimension slice of ``direction`` gets unit norm.
        'dlayer':  ``direction`` as a whole gets unit norm.

    Any other value of ``norm`` leaves ``direction`` untouched. Nothing is
    returned.
    """
    if norm == 'filter':
        for d_slice, w_slice in zip(direction, weights):
            # The small constant keeps the ratio finite for all-zero slices.
            scale = w_slice.norm() / (d_slice.norm() + 1e-10)
            d_slice.mul_(scale)
        return

    if norm == 'layer':
        direction.mul_(weights.norm() / direction.norm())
        return

    if norm == 'weight':
        direction.mul_(weights)
        return

    if norm == 'dfilter':
        for d_slice in direction:
            d_slice.div_(d_slice.norm() + 1e-10)
        return

    if norm == 'dlayer':
        direction.div_(direction.norm())


def normalize_directions_for_weights(direction, weights, norm='filter', ignore='biasbn'):
    """Normalize a list of direction tensors against the matching weights.

    All tensors in ``direction`` are modified in place:

    * tensors with more than one dimension are rescaled by
      ``normalize_direction(d, w, norm)``;
    * tensors with at most one dimension are zeroed when ``ignore`` is
      ``'biasbn'``, and otherwise overwritten with a copy of the weight.

    Raises:
        AssertionError: if ``direction`` and ``weights`` differ in length.
    """
    assert len(direction) == len(weights)
    for d, w in zip(direction, weights):
        if d.dim() > 1:
            normalize_direction(d, w, norm)
        elif ignore == 'biasbn':
            # Drop the direction for 1-dimensional (or scalar) entries.
            d.fill_(0)
        else:
            # Use the weight values themselves as the direction.
            d.copy_(w)

# import torch
# import numpy as np

# class PCA:
#     def __init__(self, n_components=None):
#         self.n_components = n_components  # number of principal components to keep
#         self.components_ = None           # principal components, empty until fit() is called

#     def fit(self, X):
#         """
#         Fit the PCA model on the input data X.
#         :param X: input data of shape [n_samples, n_features]
#         :return: self, so that calls can be chained
#         """
#         # Convert the input data to a tensor
#         X = torch.tensor(X, dtype=torch.float32)

#         # Center the data (subtract the mean)
#         X_centered = X - X.mean(dim=0)

#         # Compute the covariance matrix
#         cov_matrix = torch.cov(X_centered.T)

#         # Eigendecomposition: obtain eigenvalues and eigenvectors
#         eigenvalues, eigenvectors = torch.linalg.eigh(cov_matrix)

#         # Sort by eigenvalue in descending order
#         sorted_indices = torch.argsort(eigenvalues, descending=True)
#         eigenvectors_sorted = eigenvectors[:, sorted_indices]
#         eigenvalues_sorted = eigenvalues[sorted_indices]

#         # If n_components is given, keep only the first n_components eigenvectors
#         if self.n_components is not None:
#             eigenvectors_sorted = eigenvectors_sorted[:, :self.n_components]
#             eigenvalues_sorted = eigenvalues_sorted[:self.n_components]

#         # Store the principal components (the sorted eigenvalues are not stored)
#         self.components_ = eigenvectors_sorted  # principal components

#         return self  # return the PCA object so that calls can be chained

#     def transform(self, X):
#         """
#         Project the data onto the principal components of the fitted PCA model.
#         :param X: input data of shape [n_samples, n_features]
#         :return: projected data of shape [n_samples, n_components]
#         """
#         X = torch.tensor(X, dtype=torch.float32)
#         X_centered = X - X.mean(dim=0)
#         return torch.matmul(X_centered, self.components_)

#     def fit_transform(self, X):
#         """
#         Fit the PCA model and project the data onto the principal components.
#         :param X: input data of shape [n_samples, n_features]
#         :return: projected data of shape [n_samples, n_components]
#         """
#         self.fit(X)
#         return self.transform(X)
    
    
def check_params_equal(weights1, weights2):
    """Check that two tensor sequences are pairwise equal.

    The sequences are paired with ``zip`` and compared with ``torch.equal``
    (same size and same elements). On the first mismatch a warning is
    printed once and the comparison stops.

    Args:
        weights1: Iterable of tensors.
        weights2: Iterable of tensors, paired with ``weights1``.

    Returns:
        ``True`` if every compared pair is equal, ``False`` otherwise.
    """
    for w1, w2 in zip(weights1, weights2):
        if not torch.equal(w1, w2):
            # One message is enough: a single differing tensor already means
            # the two parameter sets differ.
            print("The first two trajectories in the trajectories file do not have same origin!")
            return False
    return True



